{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ],
      "metadata": {
        "id": "Tce3stUlHN0L"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "outputs": [],
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# BERT Preprocessing with TF Text"
      ],
      "metadata": {
        "id": "qFdPvlXBOdUN"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/text/guide/bert_preprocessing_guide\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/guide/bert_preprocessing_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/guide/bert_preprocessing_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/guide/bert_preprocessing_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ],
      "metadata": {
        "id": "MfBg1C5NB3X0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Overview\n",
        "\n",
        "Text preprocessing is the end-to-end transformation of raw text into a model’s integer inputs. NLP models are often accompanied by several hundreds (if not thousands) of lines of Python code for preprocessing text. Text preprocessing is often a challenge for models because:\n",
        "\n",
        "* **Training-serving skew.** It becomes increasingly difficult to ensure that the preprocessing logic of the model's inputs are consistent at all stages of model development (e.g. pretraining, fine-tuning, evaluation, inference). \n",
        "Using different hyperparameters, tokenization, string preprocessing algorithms or simply packaging model inputs inconsistently at different stages could yield hard-to-debug and disastrous effects to the model. \n",
        "\n",
        "* **Efficiency and flexibility.** While preprocessing can be done offline (e.g. by writing out processed outputs to files on disk and then reconsuming said preprocessed data in the input pipeline), this method incurs an additional file read and write cost. Preprocessing offline is also inconvenient if there are preprocessing decisions that need to happen dynamically. Experimenting with a different option would require regenerating the dataset again.\n",
        "\n",
        "* **Complex model interface.** Text models are much more understandable when their inputs are pure text. It's hard to understand a model when its inputs require an extra, indirect encoding step. Reducing the preprocessing complexity is especially appreciated for model debugging, serving, and evaluation. \n",
        "\n",
        "Additionally, simpler model interfaces also make it more convenient to try the model (e.g. inference or training) on different, unexplored datasets.\n"
      ],
      "metadata": {
        "id": "xHxb-dlhMIzW"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Text preprocessing with TF.Text\n",
        "\n",
        "Using TF.Text's text preprocessing APIs, we can construct a preprocessing\n",
        "function that can transform a user's text dataset into the model's\n",
        "integer inputs. Users can package preprocessing directly as part of their model to alleviate the above mentioned problems.\n",
        "\n",
        "This tutorial will show how to use TF.Text preprocessing ops to transform text data into inputs for the BERT model and inputs for language masking pretraining task described in \"Masked LM and Masking Procedure\" of [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). The process involves tokenizing text into subword units, combining sentences, trimming content to a fixed size and extracting labels for the masked language modeling task."
      ],
      "metadata": {
        "id": "Y6DTHtXbxPgw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Setup"
      ],
      "metadata": {
        "id": "MUXex9ctTuDB"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's import the packages and libraries we need first."
      ],
      "metadata": {
        "id": "pmIjNKsfeTpm"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "!pip install -q -U tensorflow-text"
      ],
      "outputs": [],
      "metadata": {
        "id": "gTWQ5swI7FRJ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_text as text\n",
        "import functools"
      ],
      "outputs": [],
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Our data contains two text features and we can create a example `tf.data.Dataset`. Our goal is to create a function that we can supply `Dataset.map()` with to be used in training."
      ],
      "metadata": {
        "id": "-brDHSrRaMii"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "examples = {\n",
        "    \"text_a\": [\n",
        "      b\"Sponge bob Squarepants is an Avenger\",\n",
        "      b\"Marvel Avengers\"\n",
        "    ],\n",
        "    \"text_b\": [\n",
        "     b\"Barack Obama is the President.\",\n",
        "     b\"President is the highest office\"\n",
        "  ],\n",
        "}\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices(examples)\n",
        "next(iter(dataset))"
      ],
      "outputs": [],
      "metadata": {
        "id": "DQyj7OQ9yk7K"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Tokenizing\n",
        "\n",
        "Our first step is to run any string preprocessing and tokenize our dataset. This can be done using the [`text.BertTokenizer`](https://tensorflow.org/text/api_docs/python/text/BertTokenizer), which is a [`text.Splitter`](https://tensorflow.org/text/api_docs/python/text/Splitter) that can tokenize sentences into subwords or wordpieces for the [BERT model](https://github.com/google-research/bert) given a vocabulary generated from the [Wordpiece algorithm](https://www.tensorflow.org/text/guide/subwords_tokenizer#optional_the_algorithm). You can learn more about other subword tokenizers available in TF.Text from [here](https://www.tensorflow.org/text/guide/subwords_tokenizer). \n",
        "\n",
        "\n",
        "The vocabulary can be from a previously generated BERT checkpoint, or you can generate one yourself on your own data. For the purposes of this example, let's create a toy vocabulary:"
      ],
      "metadata": {
        "id": "1laUIs3g5Qsz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "_VOCAB = [\n",
        "    # Special tokens\n",
        "    b\"[UNK]\", b\"[MASK]\", b\"[RANDOM]\", b\"[CLS]\", b\"[SEP]\",\n",
        "    # Suffixes\n",
        "    b\"##ack\", b\"##ama\", b\"##ger\", b\"##gers\", b\"##onge\", b\"##pants\",  b\"##uare\",\n",
        "    b\"##vel\", b\"##ven\", b\"an\", b\"A\", b\"Bar\", b\"Hates\", b\"Mar\", b\"Ob\",\n",
        "    b\"Patrick\", b\"President\", b\"Sp\", b\"Sq\", b\"bob\", b\"box\", b\"has\", b\"highest\",\n",
        "    b\"is\", b\"office\", b\"the\",\n",
        "]\n",
        "\n",
        "_START_TOKEN = _VOCAB.index(b\"[CLS]\")\n",
        "_END_TOKEN = _VOCAB.index(b\"[SEP]\")\n",
        "_MASK_TOKEN = _VOCAB.index(b\"[MASK]\")\n",
        "_RANDOM_TOKEN = _VOCAB.index(b\"[RANDOM]\")\n",
        "_UNK_TOKEN = _VOCAB.index(b\"[UNK]\")\n",
        "_MAX_SEQ_LEN = 8\n",
        "_MAX_PREDICTIONS_PER_BATCH = 5\n",
        " \n",
        "_VOCAB_SIZE = len(_VOCAB)\n",
        "\n",
        "lookup_table = tf.lookup.StaticVocabularyTable(\n",
        "    tf.lookup.KeyValueTensorInitializer(\n",
        "      keys=_VOCAB,\n",
        "      key_dtype=tf.string,\n",
        "      values=tf.range(\n",
        "          tf.size(_VOCAB, out_type=tf.int64), dtype=tf.int64),\n",
        "      value_dtype=tf.int64),\n",
        "      num_oov_buckets=1\n",
        ")"
      ],
      "outputs": [],
      "metadata": {
        "id": "ChpIFy515S1z"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's construct a [`text.BertTokenizer`](https://tensorflow.org/text/api_docs/python/text/BertTokenizer) using the above vocabulary and tokenize the text inputs into a [`RaggedTensor`](https://www.tensorflow.org/api_docs/python/tf/RaggedTensor).`."
      ],
      "metadata": {
        "id": "7t2tgbSn6nvX"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "bert_tokenizer = text.BertTokenizer(lookup_table, token_out_type=tf.string)\n",
        "bert_tokenizer.tokenize(examples[\"text_a\"])"
      ],
      "outputs": [],
      "metadata": {
        "id": "564UPrFB5Zm6"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "bert_tokenizer.tokenize(examples[\"text_b\"])"
      ],
      "outputs": [],
      "metadata": {
        "id": "AiTs3_FHHBlR"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Text output from [`text.BertTokenizer`](https://tensorflow.org/text/api_docs/python/text/BertTokenizer) allows us see how the text is being tokenized, but the model requires integer IDs. We can set the `token_out_type` param to `tf.int64` to obtain integer IDs (which are the indices into the vocabulary)."
      ],
      "metadata": {
        "id": "cK6DHjio65MV"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "bert_tokenizer = text.BertTokenizer(lookup_table, token_out_type=tf.int64)\n",
        "segment_a = bert_tokenizer.tokenize(examples[\"text_a\"])\n",
        "segment_a"
      ],
      "outputs": [],
      "metadata": {
        "id": "odeosiPz58Qu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "segment_b = bert_tokenizer.tokenize(examples[\"text_b\"])\n",
        "segment_b"
      ],
      "outputs": [],
      "metadata": {
        "id": "v4IP2P4EHQpa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "[`text.BertTokenizer`](https://tensorflow.org/text/api_docs/python/text/BertTokenizer) returns a `RaggedTensor` with shape `[batch, num_tokens, num_wordpieces]`. Because we don't need the extra `num_tokens` dimensions for our current use case,  we can merge the last two dimensions to obtain a `RaggedTensor` with shape `[batch, num_wordpieces]`:"
      ],
      "metadata": {
        "id": "TU3GJ0jx94fx"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "segment_a = segment_a.merge_dims(-2, -1)\n",
        "segment_a"
      ],
      "outputs": [],
      "metadata": {
        "id": "Fb5vt5dA-Rwf"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "segment_b = segment_b.merge_dims(-2, -1)\n",
        "segment_b"
      ],
      "outputs": [],
      "metadata": {
        "id": "NyEW0sjhHoPM"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Content Trimming\n",
        "\n",
        "The main input to BERT is a concatenation of two sentences. However, BERT requires inputs to be in a fixed-size and shape and we may have content which exceed our budget. \n",
        "\n",
        "We can tackle this by using a [`text.Trimmer`](https://tensorflow.org/text/api_docs/python/text/Trimmer) to trim our content down to a predetermined size (once concatenated along the last axis). There are different `text.Trimmer` types which select content to preserve using different algorithms. [`text.RoundRobinTrimmer`](https://tensorflow.org/text/api_docs/python/text/RoundRobinTrimmer) for example will allocate quota equally for each segment but may trim the ends of sentences. [`text.WaterfallTrimmer`](https://tensorflow.org/text/api_docs/python/text/WaterfallTrimmer) will trim starting from the end of the last sentence.\n",
        "\n",
        "For our example, we will use `RoundRobinTrimmer` which selects items from each segment in a left-to-right manner.\n"
      ],
      "metadata": {
        "id": "R9YicLN5UFkz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "trimmer = text.RoundRobinTrimmer(max_seq_length=[_MAX_SEQ_LEN])\n",
        "trimmed = trimmer.trim([segment_a, segment_b])\n",
        "trimmed"
      ],
      "outputs": [],
      "metadata": {
        "id": "aLV-1uDgwFnr"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "`trimmed` now contains the segments where the number of elements across a batch is 8 elements (when concatenated along axis=-1)."
      ],
      "metadata": {
        "id": "zPj7jM9oQ-P3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Combining segments\n",
        "\n",
        "Now that we have segments trimmed, we can combine them together to get a single `RaggedTensor`. BERT uses special tokens to indicate the beginning (`[CLS]`) and end of a segment (`[SEP]`). We also need a `RaggedTensor` indicating which items in the combined `Tensor` belong to which segment. We can use [`text.combine_segments()`](https://tensorflow.org/text/api_docs/python/text/combine_segments) to get both of these `Tensor` with special tokens inserted."
      ],
      "metadata": {
        "id": "3J2AWfmAUio8"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "segments_combined, segments_ids = text.combine_segments(\n",
        "  [segment_a, segment_b],\n",
        "  start_of_sequence_id=_START_TOKEN, end_of_segment_id=_END_TOKEN)\n",
        "segments_combined, segments_ids"
      ],
      "outputs": [],
      "metadata": {
        "id": "L-5nMh5pk8x1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Masked Language Model Task\n",
        "\n",
        "Now that we have our basic inputs, we can begin to extract the inputs needed for the \"Masked LM and Masking Procedure\" task described in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf)\n",
        "\n",
        "The masked language model task has two sub-problems for us to think about: (1) what items to select for masking and (2) what values are they assigned? \n"
      ],
      "metadata": {
        "id": "hSKla2OxUOWl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Item Selection\n",
        "Because we will choose to select items randomly for masking, we will use a [`text.RandomItemSelector`](https://tensorflow.org/text/api_docs/python/text/RandomItemSelector). `RandomItemSelector` randomly selects items in a batch subject to restrictions given (`max_selections_per_batch`, `selection_rate` and `unselectable_ids`) and returns a boolean mask indicating which items were selected."
      ],
      "metadata": {
        "id": "mkx4w9-3DT0p"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "random_selector = text.RandomItemSelector(\n",
        "    max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,\n",
        "    selection_rate=0.2,\n",
        "    unselectable_ids=[_START_TOKEN, _END_TOKEN, _UNK_TOKEN]\n",
        ")\n",
        "selected = random_selector.get_selection_mask(\n",
        "    segments_combined, axis=1)\n",
        "selected\n"
      ],
      "outputs": [],
      "metadata": {
        "id": "94BncqVkVJT2"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Choosing the Masked Value\n",
        "\n",
        "The methodology described the original BERT paper for choosing the value for masking is as follows:\n",
        "\n",
        "For `mask_token_rate` of the time, replace the item with the `[MASK]` token:\n",
        "\n",
        "    \"my dog is hairy\" -> \"my dog is [MASK]\"\n",
        " \n",
        "For `random_token_rate` of the time, replace the item with a random word:\n",
        "\n",
        "    \"my dog is hairy\" -> \"my dog is apple\"\n",
        " \n",
        "For `1 - mask_token_rate - random_token_rate` of the time, keep the item\n",
        "unchanged:\n",
        "\n",
        "    \"my dog is hairy\" -> \"my dog is hairy.\"\n",
        "\n",
        "[`text.MaskedValuesChooser`](https://tensorflow.org/text/api_docs/python/text/MaskValuesChooser) encapsulates this logic and can be used for our preprocessing function. Here's an example of what `MaskValuesChooser` returns given a `mask_token_rate` of 80% and default `random_token_rate`:\n"
      ],
      "metadata": {
        "id": "p4NAHL_GUi-C"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "input_ids = tf.ragged.constant([[19, 7, 21, 20, 9, 8], [13, 4, 16, 5], [15, 10, 12, 11, 6]])\n",
        "mask_values_chooser = text.MaskValuesChooser(_VOCAB_SIZE, _MASK_TOKEN, 0.8)\n",
        "mask_values_chooser.get_mask_values(input_ids)"
      ],
      "outputs": [],
      "metadata": {
        "id": "Amk0Lqd5VJ4n"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "When supplied with a `RaggedTensor` input, `text.MaskValuesChooser` returns a `RaggedTensor` of the same shape with either `_MASK_VALUE` (0), a random ID, or the same unchanged id."
      ],
      "metadata": {
        "id": "UCp1CQcPC6IT"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Generating Inputs for Masked Language Model Task\n",
        "\n",
        "Now that we have a `RandomItemSelector` to help us select items for masking and `text.MaskValuesChooser` to assign the values, we can use [`text.mask_language_model()`](https://tensorflow.org/text/api_docs/python/text/mask_language_model) to assemble all the inputs of this task for our BERT model.\n"
      ],
      "metadata": {
        "id": "EYpKg_sLUi1B"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "masked_token_ids, masked_pos, masked_lm_ids = text.mask_language_model(\n",
        "  segments_combined,\n",
        "  item_selector=random_selector, mask_values_chooser=mask_values_chooser)"
      ],
      "outputs": [],
      "metadata": {
        "id": "Q0fqQzXGUrkM"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's dive deeper and examine the outputs of `mask_language_model()`. The output of `masked_token_ids` is:"
      ],
      "metadata": {
        "id": "pJqcbOJ0AYBX"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "masked_token_ids"
      ],
      "outputs": [],
      "metadata": {
        "id": "PavYEhmN_tHa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Remember that our input is encoded using a vocabulary. If we decode `masked_token_ids` using our vocabulary, we get:"
      ],
      "metadata": {
        "id": "Q0c2wkC9AnUX"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "tf.gather(_VOCAB, masked_token_ids)"
      ],
      "outputs": [],
      "metadata": {
        "id": "5axqrUOc_0h1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Notice that some wordpiece tokens have been replaced with either `[MASK]`, `[RANDOM]` or a different ID value. `masked_pos` output gives us the indices (in the respective batch) of the tokens that have been replaced."
      ],
      "metadata": {
        "id": "v8DCOtEAiz_E"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "masked_pos"
      ],
      "outputs": [],
      "metadata": {
        "id": "d-nc5m5Y_wP_"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "`masked_lm_ids` gives us the original value of the token."
      ],
      "metadata": {
        "id": "6fua7ANijN3_"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        " masked_lm_ids"
      ],
      "outputs": [],
      "metadata": {
        "id": "azzxmO_f_xJp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can again decode the IDs here to get human readable values."
      ],
      "metadata": {
        "id": "5bW0rdX9jYh-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "tf.gather(_VOCAB, masked_lm_ids)"
      ],
      "outputs": [],
      "metadata": {
        "id": "F-RP-paUjUuP"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Padding Model Inputs\n",
        "\n",
        "Now that we have all the inputs for our model, the last step in our preprocessing is to package them into fixed 2-dimensional `Tensor`s with padding and also generate a mask `Tensor` indicating the values which are pad values. We can use [`text.pad_model_inputs()`](https://tensorflow.org/text/api_docs/python/text/pad_model_inputs) to help us with this task."
      ],
      "metadata": {
        "id": "3-P0PTiCUz2J"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "# Prepare and pad combined segment inputs\n",
        "input_word_ids, input_mask = text.pad_model_inputs(\n",
        "  masked_token_ids, max_seq_length=_MAX_SEQ_LEN)\n",
        "input_type_ids, _ = text.pad_model_inputs(\n",
        "  masked_token_ids, max_seq_length=_MAX_SEQ_LEN)\n",
        "\n",
        "# Prepare and pad masking task inputs\n",
        "masked_lm_positions, masked_lm_weights = text.pad_model_inputs(\n",
        "  masked_token_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)\n",
        "masked_lm_ids, _ = text.pad_model_inputs(\n",
        "  masked_lm_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)\n",
        "\n",
        "model_inputs = {\n",
        "    \"input_word_ids\": input_word_ids,\n",
        "    \"input_mask\": input_mask,\n",
        "    \"input_type_ids\": input_type_ids,\n",
        "    \"masked_lm_ids\": masked_lm_ids,\n",
        "    \"masked_lm_positions\": masked_lm_positions,\n",
        "    \"masked_lm_weights\": masked_lm_weights,\n",
        "}\n",
        "model_inputs"
      ],
      "outputs": [],
      "metadata": {
        "id": "FGE7XuXRwsYF"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Review"
      ],
      "metadata": {
        "id": "KIWy4nVyT6gf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's review what we have so far and assemble our preprocessing function. Here's what we have:"
      ],
      "metadata": {
        "id": "TwCdO1Z5yjS-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "def bert_pretrain_preprocess(vocab_table, features):\n",
        "  # Input is a string Tensor of documents, shape [batch, 1].\n",
        "  text_a = features[\"text_a\"]\n",
        "  text_b = features[\"text_b\"]\n",
        "\n",
        "  # Tokenize segments to shape [num_sentences, (num_words)] each.\n",
        "  tokenizer = text.BertTokenizer(\n",
        "      vocab_table,\n",
        "      token_out_type=tf.int64)\n",
        "  segments = [tokenizer.tokenize(text).merge_dims(\n",
        "      1, -1) for text in (text_a, text_b)]\n",
        "\n",
        "  # Truncate inputs to a maximum length.\n",
        "  trimmer = text.RoundRobinTrimmer(max_seq_length=6)\n",
        "  trimmed_segments = trimmer.trim(segments)\n",
        "\n",
        "  # Combine segments, get segment ids and add special tokens.\n",
        "  segments_combined, segment_ids = text.combine_segments(\n",
        "      trimmed_segments,\n",
        "      start_of_sequence_id=_START_TOKEN,\n",
        "      end_of_segment_id=_END_TOKEN)\n",
        "\n",
        "  # Apply dynamic masking task.\n",
        "  masked_input_ids, masked_lm_positions, masked_lm_ids = (\n",
        "      text.mask_language_model(\n",
        "        segments_combined,\n",
        "        random_selector,\n",
        "        mask_values_chooser,\n",
        "      )\n",
        "  )\n",
        "  \n",
        "  # Prepare and pad combined segment inputs\n",
        "  input_word_ids, input_mask = text.pad_model_inputs(\n",
        "    masked_input_ids, max_seq_length=_MAX_SEQ_LEN)\n",
        "  input_type_ids, _ = text.pad_model_inputs(\n",
        "    masked_input_ids, max_seq_length=_MAX_SEQ_LEN)\n",
        "\n",
        "  # Prepare and pad masking task inputs\n",
        "  masked_lm_positions, masked_lm_weights = text.pad_model_inputs(\n",
        "    masked_input_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)\n",
        "  masked_lm_ids, _ = text.pad_model_inputs(\n",
        "    masked_lm_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)\n",
        "\n",
        "  model_inputs = {\n",
        "      \"input_word_ids\": input_word_ids,\n",
        "      \"input_mask\": input_mask,\n",
        "      \"input_type_ids\": input_type_ids,\n",
        "      \"masked_lm_ids\": masked_lm_ids,\n",
        "      \"masked_lm_positions\": masked_lm_positions,\n",
        "      \"masked_lm_weights\": masked_lm_weights,\n",
        "  }\n",
        "  return model_inputs"
      ],
      "outputs": [],
      "metadata": {
        "id": "7jKtbVCYTsIC"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We previously constructed a `tf.data.Dataset` and we can now use our assembled preprocessing function `bert_pretrain_preprocess()` in `Dataset.map()`. This allows us to create an input pipeline for transforming our raw string data into integer inputs and feed directly into our model."
      ],
      "metadata": {
        "id": "bOAeo97VyfQg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "dataset = tf.data.Dataset.from_tensors(examples)\n",
        "dataset = dataset.map(functools.partial(\n",
        "    bert_pretrain_preprocess, lookup_table))\n",
        "\n",
        "next(iter(dataset))"
      ],
      "outputs": [],
      "metadata": {
        "id": "xm4gTLEgjTa3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Related Tutorials\n",
        "\n",
        "* [Classify text with BERT](https://www.tensorflow.org/text/tutorials/classify_text_with_bert) - A tutorial on how to use a pretrained BERT model to classify text. This is a nice follow up now that you are familiar with how to preprocess the inputs used by the BERT model.\n",
        "\n",
        "* [Tokenizing with TF Text](https://www.tensorflow.org/text/guide/tokenizers) - Tutorial detailing the different types of tokenizers that exist in TF.Text.\n",
        "\n",
        "* [Handling Text with `RaggedTensor`](https://www.tensorflow.org/guide/ragged_tensor) - Detailed guide on how to create, use and manipulate `RaggedTensor`s.\n"
      ],
      "metadata": {
        "id": "FyiMxeEp0m2O"
      }
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "bert_preprocessing_guide.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}